Add partial support for from_protobuf #14062
base: main
Conversation
Signed-off-by: Haoyang Li <[email protected]>
@greptile full review
Pull request overview
This PR adds partial GPU support for Spark's from_protobuf function, enabling decoding of binary protobuf data into Spark SQL structs. The implementation is intentionally limited to simple scalar types (boolean, int32, int64, float, double, string) and targets Spark 3.4.0+, where spark-protobuf is available as an optional external module.
Key changes:
- GPU expression implementation for protobuf decoding with simple types only
- Reflection-based shim layer to optionally register protobuf expressions when spark-protobuf module is available
- Build configuration updates to optionally include spark-protobuf JAR for integration testing
- Python integration tests with custom data generators for protobuf message encoding
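For context, here is a minimal CPU-side usage sketch of the API this PR accelerates. It assumes Spark 3.4+ with the optional spark-protobuf module on the classpath; the input path and descriptor path are placeholders and not taken from this PR (the message name test.Simple mirrors the one used in the PR's tests).

```scala
// Minimal sketch (not code from this PR): decode a binary protobuf column with Spark's
// built-in from_protobuf. The GPU plugin aims to accelerate this expression when the
// decoded message only uses simple scalar fields.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.protobuf.functions.from_protobuf

object FromProtobufExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("from_protobuf-example").getOrCreate()

    // "payload" is assumed to be a BinaryType column holding serialized test.Simple messages.
    val df = spark.read.parquet("/tmp/protobuf_input.parquet")

    // "/tmp/simple.desc" is assumed to be a FileDescriptorSet (e.g. from protoc --descriptor_set_out).
    val decoded = df.select(
      from_protobuf(col("payload"), "test.Simple", "/tmp/simple.desc").alias("msg"))

    decoded.select("msg.*").show()
    spark.stop()
  }
}
```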
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 11 comments.
Summary per file:
| File | Description |
|---|---|
| sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/Spark340PlusNonDBShims.scala | Registers protobuf expression rules in the expression mapping |
| sql-plugin/src/main/spark340/scala/com/nvidia/spark/rapids/shims/ProtobufExprShims.scala | Implements reflection-based shim for ProtobufDataToCatalyst with GPU fallback rules and validation |
| sql-plugin/src/main/scala/org/apache/spark/sql/rapids/GpuFromProtobufSimple.scala | GPU implementation for from_protobuf decoding with null-intolerant behavior |
| sql-plugin/src/main/scala/com/nvidia/spark/rapids/protobuf/ProtobufDescriptorUtils.scala | Utility for parsing protobuf FileDescriptorSet (currently unused) |
| pom.xml | Adds maven property to control spark-protobuf dependency inclusion for Spark 3.4+ profiles |
| integration_tests/pom.xml | Configures maven to copy spark-protobuf JAR to dependency directory for tests |
| integration_tests/src/main/python/protobuf_test.py | Integration tests for from_protobuf with parquet round-trip and null input handling |
| integration_tests/src/main/python/data_gen.py | Protobuf message encoder and test data generator for simple scalar types |
| integration_tests/run_pyspark_from_build.sh | Updates test runner to conditionally include spark-protobuf JAR on classpath |
Copilot AI (Dec 23, 2025):
Remove the trailing blank lines at the end of the file. This follows standard code style guidelines and maintains consistency across the codebase.
Copilot AI (Dec 23, 2025):
Remove the trailing blank lines at the end of the file. This follows standard code style guidelines and maintains consistency across the codebase.
Copilot AI (Dec 23, 2025):
Remove the trailing blank lines at the end of the file. This follows standard code style guidelines and maintains consistency across the codebase.
integration_tests/pom.xml (outdated diff)
`</project>`
Copilot AI (Dec 23, 2025):
Remove the trailing blank line at the end of the file. This follows standard code style guidelines and maintains consistency across the codebase.
Copilot AI (Dec 23, 2025):
Remove the trailing blank line at the end of the file. This follows standard code style guidelines and maintains consistency across the codebase.
    # if from_protobuf is None:
    #     pytest.skip("pyspark.sql.protobuf.functions.from_protobuf is not available")
    # if not with_cpu_session(lambda spark: _spark_protobuf_jvm_available(spark)):
    #     pytest.skip("spark-protobuf JVM module is not available on the classpath")
Copilot AI (Dec 23, 2025):
These commented-out lines should either be removed or uncommented if the checks are needed. Leaving commented-out code in production reduces maintainability. If these checks are intentionally disabled for this WIP PR, consider adding a TODO comment explaining why and when they should be re-enabled.
Suggested change (enable the skip guards instead of leaving them commented out):

    if from_protobuf is None:
        pytest.skip("pyspark.sql.protobuf.functions.from_protobuf is not available")
    if not with_cpu_session(lambda spark: _spark_protobuf_jvm_available(spark)):
        pytest.skip("spark-protobuf JVM module is not available on the classpath")
    message_name = "test.Simple"

    # Generate descriptor bytes once using the JVM (no protoc dependency)
    desc_bytes = with_cpu_session(lambda spark: _build_simple_descriptor_set_bytes(spark))
Copilot AI (Dec 23, 2025):
This 'lambda' is just a simple wrapper around a callable object. Use that object directly.
Suggested change:

    desc_bytes = with_cpu_session(_build_simple_descriptor_set_bytes)
    message_name = "test.Simple"

    # Generate descriptor bytes once using the JVM (no protoc dependency)
    desc_bytes = with_cpu_session(lambda spark: _build_simple_descriptor_set_bytes(spark))
Copilot AI (Dec 23, 2025):
This 'lambda' is just a simple wrapper around a callable object. Use that object directly.
Suggested change:

    desc_bytes = with_cpu_session(_build_simple_descriptor_set_bytes)
    raise ValueError("Unsupported type for protobuf simple generator: {}".format(spark_type))


    class ProtobufSimpleMessageRowGen(DataGen):
Copilot AI (Dec 23, 2025):
The class ProtobufSimpleMessageRowGen adds the new attributes _fields and _binary_col_name but does not override __eq__.
    # leaving syntax unset is sufficient/compatible.
    try:
        fd = fd.setSyntax("proto2")
    except Exception:
Copilot AI (Dec 23, 2025):
'except' clause does nothing but pass and there is no explanatory comment.
Suggested change (add an explanatory comment):

    except Exception:
        # If setSyntax is unavailable (older protobuf-java), we intentionally leave syntax unset.
Greptile Overview

Greptile Summary

This PR adds partial GPU support for Spark's from_protobuf function.

Key Changes
- Core Implementation:
- Shim Layer Integration:
- Test Infrastructure:

Architecture Decisions

Limitations (Expected for "Partial Support")

Issues Found
- Style Issue: StructType matching in schema projection doesn't check nullable flags, which could cause false positive matches when identifying protobuf output references (lines 446-451 in ProtobufExprShims.scala).

Confidence Score: 4/5

Important Files Changed

File Analysis
Sequence Diagram

sequenceDiagram
participant User
participant Spark SQL
participant ProtobufDataToCatalyst
participant ProtobufExprShims
participant GpuFromProtobuf
participant JNI
participant cuDF
User->>Spark SQL: from_protobuf(col, message_name, desc_path)
Spark SQL->>ProtobufDataToCatalyst: Create expression
Note over ProtobufExprShims: GPU Override Registration
ProtobufExprShims->>ProtobufExprShims: tagExprForGpu()
ProtobufExprShims->>ProtobufExprShims: Resolve descriptor via reflection
ProtobufExprShims->>ProtobufExprShims: Analyze all fields (types, encoding)
ProtobufExprShims->>ProtobufExprShims: Detect required fields (schema projection)
ProtobufExprShims->>ProtobufExprShims: Check all required fields supported
alt All required fields supported
ProtobufExprShims->>GpuFromProtobuf: convertToGpu()
GpuFromProtobuf->>JNI: Protobuf.decodeToStruct()
JNI->>cuDF: Decode protobuf binary to struct
cuDF-->>JNI: Decoded struct (required fields only)
JNI-->>GpuFromProtobuf: Column vector
GpuFromProtobuf->>GpuFromProtobuf: Build full schema struct
GpuFromProtobuf->>GpuFromProtobuf: Fill non-decoded fields with nulls
GpuFromProtobuf->>GpuFromProtobuf: Apply input nulls
GpuFromProtobuf-->>Spark SQL: Result column
else Unsupported fields required
ProtobufExprShims->>Spark SQL: Fall back to CPU
Spark SQL->>ProtobufDataToCatalyst: Execute on CPU
ProtobufDataToCatalyst-->>Spark SQL: Result column
end
Spark SQL-->>User: DataFrame with decoded data
Additional Comments (1)
- integration_tests/src/main/python/protobuf_test.py, lines 133-136 (syntax): commented-out test skips will cause tests to fail if spark-protobuf is not available.
9 files reviewed, 1 comment
Signed-off-by: Haoyang Li <[email protected]>
Pull request overview
Copilot reviewed 10 out of 10 changed files in this pull request and generated 7 comments.
    val (tid, _) = GpuFromProtobufSimple.sparkTypeToCudfId(sf.dataType)
    typeIds(idx) = tid
    scales(idx) = encoding.get
Copilot AI (Dec 25, 2025):
The second return value from sparkTypeToCudfId (the encoding) is being ignored and replaced with the encoding determined from the protobuf type. This suggests that sparkTypeToCudfId should not return an encoding at all, or its second return value should be used. Consider refactoring sparkTypeToCudfId to only return the type ID, since the encoding is determined elsewhere based on the protobuf field type.
Suggested change:

    val (tid, scale) = GpuFromProtobufSimple.sparkTypeToCudfId(sf.dataType)
    typeIds(idx) = tid
    if (scale != encoding.get) {
      willNotWorkOnGpu(
        s"Encoding mismatch for '${sf.name}': Spark ${sf.dataType} vs " +
        s"Protobuf $protoTypeName")
      return
    }
    scales(idx) = scale
    def sparkTypeToCudfId(dt: DataType): (Int, Int) = dt match {
      case BooleanType => (DType.BOOL8.getTypeId.getNativeId, ENC_DEFAULT)
      case IntegerType => (DType.INT32.getTypeId.getNativeId, ENC_DEFAULT)
      case LongType => (DType.INT64.getTypeId.getNativeId, ENC_DEFAULT)
      case FloatType => (DType.FLOAT32.getTypeId.getNativeId, ENC_DEFAULT)
      case DoubleType => (DType.FLOAT64.getTypeId.getNativeId, ENC_DEFAULT)
      case StringType => (DType.STRING.getTypeId.getNativeId, ENC_DEFAULT)
      case BinaryType => (DType.LIST.getTypeId.getNativeId, ENC_DEFAULT)
      case other =>
        throw new IllegalArgumentException(s"Unsupported Spark type for protobuf(simple): $other")
    }
Copilot AI (Dec 25, 2025):
The sparkTypeToCudfId function always returns ENC_DEFAULT as the second tuple element, regardless of the input type. However, this encoding value is not actually used by the caller in ProtobufExprShims.scala (line 199), which discards it with an underscore. Consider simplifying this function to return only the type ID as an Int, since the encoding is determined separately based on the protobuf wire type.
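A minimal sketch of the refactor this comment suggests, reusing the DType mapping shown in the snippet above; it is illustrative only and not the PR's actual code:

```scala
// Hypothetical simplification: return only the cuDF type id, since the encoding is
// derived separately from the protobuf field type at the call site.
import ai.rapids.cudf.DType
import org.apache.spark.sql.types._

def sparkTypeToCudfId(dt: DataType): Int = dt match {
  case BooleanType => DType.BOOL8.getTypeId.getNativeId
  case IntegerType => DType.INT32.getTypeId.getNativeId
  case LongType    => DType.INT64.getTypeId.getNativeId
  case FloatType   => DType.FLOAT32.getTypeId.getNativeId
  case DoubleType  => DType.FLOAT64.getTypeId.getNativeId
  case StringType  => DType.STRING.getTypeId.getNativeId
  case BinaryType  => DType.LIST.getTypeId.getNativeId
  case other =>
    throw new IllegalArgumentException(s"Unsupported Spark type for protobuf(simple): $other")
}

// The call site would then reduce to: typeIds(idx) = sparkTypeToCudfId(sf.dataType)
```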
    object ProtobufDescriptorUtils {

      def buildMessageDescriptor(
          fileDescriptorSetBytes: Array[Byte],
          messageName: String): Descriptors.Descriptor = {
        val fds = DescriptorProtos.FileDescriptorSet.parseFrom(fileDescriptorSetBytes)
        val protos = fds.getFileList.asScala.toSeq
        val byName = protos.map(p => p.getName -> p).toMap
        val cache = mutable.HashMap.empty[String, Descriptors.FileDescriptor]

        def buildFileDescriptor(name: String): Descriptors.FileDescriptor = {
          cache.getOrElseUpdate(name, {
            val p = byName.getOrElse(name,
              throw new IllegalArgumentException(s"Missing FileDescriptorProto for '$name'"))
            val deps = p.getDependencyList.asScala.map(buildFileDescriptor _).toArray
            Descriptors.FileDescriptor.buildFrom(p, deps)
          })
        }

        val fileDescriptors = protos.map(p => buildFileDescriptor(p.getName))
        val candidates = fileDescriptors.iterator.flatMap(fd => findMessageDescriptors(fd, messageName))
          .toSeq

        candidates match {
          case Seq(d) => d
          case Seq() =>
            throw new IllegalArgumentException(
              s"Message '$messageName' not found in FileDescriptorSet")
          case many =>
            val names = many.map(_.getFullName).distinct.sorted
            throw new IllegalArgumentException(
              s"Message '$messageName' is ambiguous; matches: ${names.mkString(", ")}")
        }
      }

      private def findMessageDescriptors(
          fd: Descriptors.FileDescriptor,
          messageName: String): Iterator[Descriptors.Descriptor] = {
        def matches(d: Descriptors.Descriptor): Boolean = {
          d.getName == messageName ||
            d.getFullName == messageName ||
            d.getFullName.endsWith("." + messageName)
        }

        def walk(d: Descriptors.Descriptor): Iterator[Descriptors.Descriptor] = {
          val nested = d.getNestedTypes.asScala.iterator.flatMap(walk _)
          if (matches(d)) Iterator.single(d) ++ nested else nested
        }

        fd.getMessageTypes.asScala.iterator.flatMap(walk _)
      }
    }
Copilot AI (Dec 25, 2025):
This utility class appears to be unused in the current implementation. The ProtobufExprShims.scala uses Spark's ProtobufUtils via reflection (buildMessageDescriptorWithSparkProtobuf) instead of this custom utility. Consider removing this file if it's not needed for future work, or add documentation explaining its intended purpose if it's meant for upcoming features.
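For background on the reflection approach this comment refers to, here is a rough sketch of how a plugin might look up spark-protobuf's descriptor utilities reflectively so the optional module can be absent. The class name, method name, and parameter order below are assumptions for illustration only and may not match the PR's buildMessageDescriptorWithSparkProtobuf helper or the actual spark-protobuf API:

```scala
// Hypothetical sketch: resolve spark-protobuf's utilities via reflection so the plugin
// can still load when the optional module is missing from the classpath. Names are assumed.
import com.google.protobuf.Descriptors

object ReflectiveDescriptorLookup {
  def tryBuildDescriptor(descFilePath: String, messageName: String): Option[Descriptors.Descriptor] = {
    try {
      // Scala objects compile to a class with a static MODULE$ field.
      val clazz = Class.forName("org.apache.spark.sql.protobuf.utils.ProtobufUtils$")
      val module = clazz.getField("MODULE$").get(null)
      val method = clazz.getMethod("buildDescriptor", classOf[String], classOf[String])
      Some(method.invoke(module, messageName, descFilePath).asInstanceOf[Descriptors.Descriptor])
    } catch {
      case _: ReflectiveOperationException =>
        None // spark-protobuf is absent, or its API shape differs from this assumption
    }
  }
}
```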
    return
    # Add driver-class-path for each jar
    jar_list = jars.replace(',', ' ').split()
    driver_cp = ':'.join(jar_list)
Copilot AI (Dec 25, 2025):
The classpath separator is hardcoded to colon (:) which is Unix-specific. On Windows, the classpath separator should be semicolon (;). Consider using os.pathsep instead of ':' to make this code platform-independent.
Suggested change:

    driver_cp = os.pathsep.join(jar_list)
    - one column per message field (Spark scalar types)
    - a binary column containing a serialized protobuf message containing those fields
    This is intentionally limited to the simple scalar types supported in Patch 1:
Copilot AI (Dec 25, 2025):
The comment refers to "Patch 1" which may be unclear to readers who are not familiar with the development context. Consider rephrasing to something like "initial implementation" or "this PR" for better clarity.
Suggested change:

    This is intentionally limited to the simple scalar types currently supported by this implementation:
    private def writeTempDescFile(descBytes: Array[Byte]): String = {
      val tmp: Path = Files.createTempFile("spark-rapids-protobuf-desc-", ".desc")
      Files.write(tmp, descBytes)
      tmp.toFile.deleteOnExit()
      tmp.toString
Copilot AI (Dec 25, 2025):
The temporary descriptor file is marked for deletion on JVM exit using deleteOnExit(), but there's no guarantee this will be called, especially in cases of abrupt termination or long-running processes. Consider using a try-finally block or resource management approach to ensure cleanup, or document that these temp files are expected to persist in some scenarios.
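A minimal sketch of one way to tighten the cleanup, assuming the descriptor path is only needed while a caller-supplied function runs; the withTempDescFile helper below is hypothetical and not part of the PR:

```scala
// Hypothetical variant: scope the temp file to the code that consumes the path and delete
// it eagerly in a finally block, keeping deleteOnExit() only as a last-resort fallback.
import java.nio.file.{Files, Path}

def withTempDescFile[T](descBytes: Array[Byte])(use: String => T): T = {
  val tmp: Path = Files.createTempFile("spark-rapids-protobuf-desc-", ".desc")
  try {
    Files.write(tmp, descBytes)
    use(tmp.toString)
  } finally {
    // Best-effort eager cleanup; a JVM exit hook could still serve as a backup.
    Files.deleteIfExists(tmp)
  }
}
```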
    from_protobuf = _try_import_from_protobuf()
    if from_protobuf is None:
        pytest.skip("pyspark.sql.protobuf.functions.from_protobuf is not available")
    if not with_cpu_session(lambda spark: _spark_protobuf_jvm_available(spark)):
Copilot AI (Dec 25, 2025):
This 'lambda' is just a simple wrapper around a callable object. Use that object directly.
Suggested change:

    if not with_cpu_session(_spark_protobuf_jvm_available):
Signed-off-by: Haoyang Li <[email protected]>
Signed-off-by: Haoyang Li <[email protected]>
Signed-off-by: Haoyang Li <[email protected]>
…spark-rapids into from_protobuffer_v0
Signed-off-by: Haoyang Li <[email protected]>
@greptileai full review
2 files reviewed, 2 comments
    try {
      for (i <- fullSchema.fields.indices) {
        if (decodedIdx < decodedFieldIndices.length && decodedFieldIndices(decodedIdx) == i) {
          // This field was decoded - extract from decoded struct
          fullChildren(i) = decodedStruct.getChildColumnView(decodedIdx).copyToColumnVector()
          decodedIdx += 1
        } else {
          // This field was not decoded - create null column
          fullChildren(i) = GpuFromProtobuf.createNullColumn(
            fullSchema.fields(i).dataType, numRows)
        }
      }
      cudf.ColumnVector.makeStruct(numRows, fullChildren: _*)
    } finally {
      fullChildren.foreach(c => if (c != null) c.close())
    }
If an exception occurs during the loop (lines 85-95), the columns that were already created in fullChildren will be leaked because the finally block only closes non-null entries, but an exception could leave fullChildren partially populated. The finally block should be inside the loop body or use a more robust pattern.
Additionally, if cudf.ColumnVector.makeStruct on line 96 throws an exception, the columns in fullChildren will be leaked because they've already been closed in the finally block, but makeStruct hasn't taken ownership yet.
The safer pattern would be:
try {
for (i <- fullSchema.fields.indices) {
val child = try {
if (decodedIdx < decodedFieldIndices.length && decodedFieldIndices(decodedIdx) == i) {
val col = decodedStruct.getChildColumnView(decodedIdx).copyToColumnVector()
decodedIdx += 1
col
} else {
GpuFromProtobuf.createNullColumn(fullSchema.fields(i).dataType, numRows)
}
} catch {
case e: Throwable =>
fullChildren.take(i).foreach(c => if (c != null) c.close())
throw e
}
fullChildren(i) = child
}
cudf.ColumnVector.makeStruct(numRows, fullChildren: _*)
} finally {
fullChildren.foreach(c => if (c != null) c.close())
}

Or use the withResource pattern more carefully to ensure proper cleanup ordering.
    case st: StructType =>
      // Compare field names and types - StructType equality can be tricky
      st.fields.length == fullSchema.fields.length &&
        st.fields.zip(fullSchema.fields).forall { case (a, b) =>
          a.name == b.name && a.dataType == b.dataType
        }
The StructType comparison doesn't account for nullable flags on struct fields, which could lead to false positive matches. Two structs with the same field names and types but different nullability should not be considered equal for this matching purpose.
Consider:
Suggested change:

    case st: StructType =>
      // Compare field names, types, and nullable flags
      st.fields.length == fullSchema.fields.length &&
        st.fields.zip(fullSchema.fields).forall { case (a, b) =>
          a.name == b.name && a.dataType == b.dataType && a.nullable == b.nullable
        }
Signed-off-by: Haoyang Li <[email protected]>
Contributes to #14069
Depends on NVIDIA/spark-rapids-jni#4107
Description
WIP
For discussions and AI review.
This PR adds partial support for from_protobuf, with a limited feature set and the supporting framework code.
Checklists
(Please explain in the PR description how the new code paths are tested, such as names of the new/existing tests that cover them.)